"""Memory for agent.

Human memory follows a general progression from sensory memory that registers
perceptual inputs, to short-term memory that maintains information transiently, to
long-term memory that consolidates information over extended periods.
"""

import asyncio
from abc import ABC, abstractmethod
from datetime import datetime
from enum import Enum
from typing import (
    Any,
    Callable,
    Dict,
    Generic,
    List,
    Optional,
    Tuple,
    Type,
    TypeVar,
    Union,
    cast,
)

from dbgpt.core import LLMClient
from dbgpt.util.annotations import PublicAPI, immutable, mutable

# Type variable for concrete memory fragment classes stored in a memory.
T = TypeVar("T", bound="MemoryFragment")
# Type variable for concrete memory classes (used e.g. so structure_clone returns
# the same memory type it was called on).
M = TypeVar("M", bound="Memory")


class WriteOperation(str, Enum):
    """Write operation.

    Passed as the ``op`` argument of ``Memory.write``. Because the enum mixes in
    ``str``, each member compares equal to its string value (e.g. ``"add"``).
    """

    ADD = "add"
    RETRIEVAL = "retrieval"


@PublicAPI(stability="beta")
class MemoryFragment(ABC):
    """Memory fragment interface.

    A memory fragment is the basic unit stored in a memory. It bundles the basic
    facts about one memory: the observation, its importance, whether it is an
    insight, when it was last accessed, and so on.
    """

    @classmethod
    @abstractmethod
    def build_from(
        cls: Type[T],
        observation: str,
        embeddings: Optional[List[float]] = None,
        memory_id: Optional[int] = None,
        importance: Optional[float] = None,
        is_insight: bool = False,
        last_accessed_time: Optional[datetime] = None,
        **kwargs,
    ) -> T:
        """Create a fragment of this class from an observation.

        Args:
            observation(str): The observation text.
            embeddings(List[float], optional): Embeddings of the fragment.
            memory_id(int, optional): Identifier of the fragment.
            importance(float, optional): Importance score.
            is_insight(bool): Whether the fragment is an insight.
            last_accessed_time(datetime, optional): Time of last access.
            **kwargs: Extra, implementation-specific arguments.

        Returns:
            MemoryFragment: The newly built fragment.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def id(self) -> int:
        """Identifier of the fragment.

        Implementations commonly generate it with a Snowflake algorithm, which
        makes the creation timestamp recoverable from the id.

        Returns:
            int: The id.
        """

    @property
    def metadata(self) -> Dict[str, Any]:
        """Extra metadata attached to the fragment; empty by default.

        Returns:
            Dict[str, Any]: Metadata.
        """
        return dict()

    @property
    def importance(self) -> Optional[float]:
        """Importance of the fragment, reflecting only the memory itself.

        Returns:
            Optional[float]: The importance, or None when it is not available.
        """
        return None

    @abstractmethod
    def update_importance(self, importance: float) -> Optional[float]:
        """Set a new importance for the fragment.

        Args:
            importance(float): The new importance.

        Returns:
            Optional[float]: The importance.
        """

    @property
    @abstractmethod
    def raw_observation(self) -> str:
        """Original observation data of the fragment.

        The observation may come from the environment or from the result of
        executing an action.

        Returns:
            str: The raw observation.
        """

    @property
    def embeddings(self) -> Optional[List[float]]:
        """Embeddings of the fragment; None by default.

        Returns:
            Optional[List[float]]: The embeddings.
        """
        return None

    @abstractmethod
    def update_embeddings(self, embeddings: List[float]) -> None:
        """Replace the embeddings of the fragment.

        Args:
            embeddings(List[float]): The new embeddings.
        """

    def calculate_current_embeddings(
        self, embedding_func: Callable[[List[str]], List[List[float]]]
    ) -> List[float]:
        """Compute embeddings for this fragment with the given function.

        Not implemented in the base class.

        Args:
            embedding_func(Callable[[List[str]], List[List[float]]]): Function that
                embeds a batch of texts.

        Returns:
            List[float]: Embeddings of the fragment.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def is_insight(self) -> bool:
        """Whether the fragment is an insight.

        Returns:
            bool: True for insight fragments.
        """

    @property
    @abstractmethod
    def last_accessed_time(self) -> Optional[datetime]:
        """Time the fragment was last accessed.

        Returns:
            Optional[datetime]: The last access time.
        """

    @abstractmethod
    def update_accessed_time(self, now: datetime) -> Optional[datetime]:
        """Record an access to the fragment.

        Args:
            now(datetime): The current time.

        Returns:
            Optional[datetime]: The last accessed time.
        """

    @abstractmethod
    def copy(self: T) -> T:
        """Return a copy of the fragment."""

    def reduce(self, memory_fragments: List[T], **kwargs) -> T:
        """Merge several fragments into one.

        The raw observations are joined with ``;`` and a new fragment of this
        fragment's class is built from the result.

        Args:
            memory_fragments(List[T]): Fragments to merge.
            **kwargs: Forwarded to ``build_from``.

        Returns:
            T: The merged fragment.
        """
        merged_observation = ";".join(
            fragment.raw_observation for fragment in memory_fragments
        )
        return self.current_class.build_from(merged_observation, **kwargs)  # type: ignore

    @property
    def current_class(self: T) -> Type[T]:
        """Concrete class of this fragment."""
        return self.__class__


class InsightMemoryFragment(Generic[T]):
    """Insight memory fragment.

    Pairs one or more original memory fragments with the insights derived from
    them. Insights are themselves memory fragments.
    """

    def __init__(
        self,
        original_memory_fragment: Union[T, List[T]],
        insights: Union[List[T], List[str]],
    ):
        """Create an insight memory fragment.

        Args:
            original_memory_fragment(Union[T, List[T]]): The fragment(s) the
                insights were derived from.
            insights(Union[List[T], List[str]]): Insights as fragments, or as plain
                strings. Strings (detected by checking the first element) are
                converted into insight fragments using the class of the (first)
                original fragment.
        """
        given_as_text = bool(insights) and isinstance(insights[0], str)
        if given_as_text:
            if isinstance(original_memory_fragment, list):
                template = original_memory_fragment[0]
            else:
                template = original_memory_fragment
            insights = [
                template.current_class.build_from(text, is_insight=True)
                for text in insights  # type: ignore # noqa
            ]
        self._original_memory_fragment = original_memory_fragment
        self._insights: List[T] = cast(List[T], insights)

    @property
    def original_memory_fragment(self) -> Union[T, List[T]]:
        """The fragment(s) the insights were derived from."""
        return self._original_memory_fragment

    @property
    def insights(self) -> List[T]:
        """The insight fragments."""
        return self._insights


class DiscardedMemoryFragments(Generic[T]):
    """Discarded memory fragments.

    Memory fragments may need to be discarded in the following cases:
    1. Memory duplication: the same or a similar action is executed multiple times
        and the same or a similar observation is received from the environment.
    2. Memory overflow: the memory is full and a new memory fragment needs to be
        written.
    3. The memory fragment is not important enough.
    4. Simulation of a forgetting mechanism.

    The discarded memory fragments may be transferred to another memory.
    """

    def __init__(
        self,
        discarded_memory_fragments: List[T],
        discarded_insights: Optional[List[InsightMemoryFragment[T]]] = None,
    ):
        """Create a container of discarded fragments and their insights.

        Args:
            discarded_memory_fragments(List[T]): The discarded fragments.
            discarded_insights(Optional[List[InsightMemoryFragment[T]]]): Insights
                belonging to the discarded fragments; None becomes an empty list.
        """
        self._discarded_memory_fragments = discarded_memory_fragments
        self._discarded_insights = (
            [] if discarded_insights is None else discarded_insights
        )

    @property
    def discarded_memory_fragments(self) -> List[T]:
        """The discarded fragments."""
        return self._discarded_memory_fragments

    @property
    def discarded_insights(self) -> List[InsightMemoryFragment[T]]:
        """Insights of the discarded fragments."""
        return self._discarded_insights


class InsightExtractor(ABC, Generic[T]):
    """Interface for deriving high-level insights from memory fragments."""

    @abstractmethod
    async def extract_insights(
        self,
        memory_fragment: T,
        llm_client: Optional[LLMClient] = None,
    ) -> InsightMemoryFragment[T]:
        """Derive insights from a single memory fragment.

        Args:
            memory_fragment(T): The fragment to analyse.
            llm_client(Optional[LLMClient]): LLM client the implementation may use.

        Returns:
            InsightMemoryFragment: The fragment together with its insights.
        """


class ImportanceScorer(ABC, Generic[T]):
    """Interface for rating how important a memory fragment is."""

    @abstractmethod
    async def score_importance(
        self,
        memory_fragment: T,
        llm_client: Optional[LLMClient] = None,
    ) -> float:
        """Rate the importance of a single memory fragment.

        Args:
            memory_fragment(T): The fragment to rate.
            llm_client(Optional[LLMClient]): LLM client the implementation may use.

        Returns:
            float: The raw importance score of the fragment.
        """


@PublicAPI(stability="beta")
class Memory(ABC, Generic[T]):
    """Memory interface.

    A memory stores fragments of type ``T``. Optional collaborators (LLM client,
    importance scorer, insight extractor) are configured through ``initialize``
    and carried over to clones by ``_copy_from``.
    """

    name: Optional[str] = None
    llm_client: Optional[LLMClient] = None
    importance_scorer: Optional[ImportanceScorer] = None
    insight_extractor: Optional[InsightExtractor] = None
    _real_memory_fragment_class: Optional[Type[T]] = None
    # Multiplier applied to raw importance scores in score_memory_importance.
    importance_weight: float = 0.15
    # The session id is used to identify the session of the agent.
    session_id: Optional[str] = None

    @mutable
    def initialize(
        self,
        name: Optional[str] = None,
        llm_client: Optional[LLMClient] = None,
        importance_scorer: Optional[ImportanceScorer] = None,
        insight_extractor: Optional[InsightExtractor] = None,
        real_memory_fragment_class: Optional[Type[T]] = None,
        session_id: Optional[str] = None,
    ) -> None:
        """Configure the memory before use.

        Some agents need to initialize their memory before using it. Every
        argument overwrites the matching attribute, including with ``None`` when
        the argument is omitted.

        Args:
            name(Optional[str]): Name of the memory.
            llm_client(Optional[LLMClient]): LLM client passed to collaborators.
            importance_scorer(Optional[ImportanceScorer]): Importance scorer.
            insight_extractor(Optional[InsightExtractor]): Insight extractor.
            real_memory_fragment_class(Optional[Type[T]]): Concrete fragment class.
            session_id(Optional[str]): Session id of the agent.
        """
        self.session_id = session_id
        self.name = name
        self._real_memory_fragment_class = real_memory_fragment_class
        self.llm_client = llm_client
        self.insight_extractor = insight_extractor
        self.importance_scorer = importance_scorer

    @abstractmethod
    @immutable
    def structure_clone(self: M, now: Optional[datetime] = None) -> M:
        """Return a copy of the memory's structure without its content.

        Typical uses:

        1. Resetting a memory: create a new, empty memory with the same structure
        as the old one.
        2. Creating a new agent whose memory has the same structure as the
        planner's.

        Args:
            now(Optional[datetime]): The current time.

        Returns:
            M: The structure clone of the memory.

        """
        raise NotImplementedError

    @mutable
    def _copy_from(self, memory: "Memory") -> None:
        """Take over the configuration (not the content) of another memory.

        Args:
            memory(Memory): The memory to copy the configuration from.
        """
        self.session_id = memory.session_id
        self.name = memory.name
        self._real_memory_fragment_class = memory._real_memory_fragment_class
        self.llm_client = memory.llm_client
        self.insight_extractor = memory.insight_extractor
        self.importance_scorer = memory.importance_scorer

    @abstractmethod
    @mutable
    async def write(
        self,
        memory_fragment: T,
        now: Optional[datetime] = None,
        op: WriteOperation = WriteOperation.ADD,
    ) -> Optional[DiscardedMemoryFragments[T]]:
        """Store a memory fragment.

        Implementations should take care of two situations:
        1. Memory duplication: the same or a similar action is executed multiple
            times and the same or a similar observation is received.
        2. Memory overflow: the memory is full when a new fragment arrives; the
            common strategy is to discard some memory fragments.

        Args:
            memory_fragment(T): The fragment to store.
            now(Optional[datetime]): The current time.
            op(WriteOperation): The write operation.

        Returns:
            Optional[DiscardedMemoryFragments]: The discarded memory fragments, or
                None when nothing was discarded.
        """

    @mutable
    async def write_batch(
        self, memory_fragments: List[T], now: Optional[datetime] = None
    ) -> Optional[DiscardedMemoryFragments[T]]:
        """Store several fragments, one ``write`` call at a time, in order.

        Args:
            memory_fragments(List[T]): The fragments to store.
            now(Optional[datetime]): The current time.

        Returns:
            Optional[DiscardedMemoryFragments]: Everything discarded across all
                writes, or None when no fragment was discarded.
        """
        all_fragments = []
        all_insights = []
        for fragment in memory_fragments:
            outcome = await self.write(fragment, now)
            if not outcome:
                continue
            # ``or []`` keeps empty/None collections from being extended.
            all_fragments.extend(outcome.discarded_memory_fragments or [])
            all_insights.extend(outcome.discarded_insights or [])
        if not all_fragments:
            return None
        return DiscardedMemoryFragments(all_fragments, all_insights)

    @abstractmethod
    @immutable
    async def read(
        self,
        observation: str,
        alpha: Optional[float] = None,
        beta: Optional[float] = None,
        gamma: Optional[float] = None,
    ) -> List[T]:
        r"""Read memory fragments relevant to an observation.

        There are three commonly used criteria for information extraction: recency,
        relevance, and importance.

        Memories that are more recent, relevant, and important are more likely to be
        extracted. Formally, existing literature expresses memory information
        extraction as:

        .. math::

            m^* = \arg\min_{m \in M} \alpha s^{\text{rec}}(q, m) + \\
            \beta s^{\text{rel}}(q, m) + \gamma s^{\text{imp}}(m), \tag{1}

        Args:
            observation(str): The observation (query).
            alpha(float, optional): Recency coefficient. Default is None.
            beta(float, optional): Relevance coefficient. Default is None.
            gamma(float, optional): Importance coefficient. Default is None.

        Returns:
            List[T]: The matching memory fragments.
        """

    @immutable
    async def reflect(self, memory_fragments: List[T]) -> List[T]:
        """Reflect on memory fragments; the base implementation is a no-op.

        Args:
            memory_fragments(List[T]): Fragments to reflect on.

        Returns:
            List[T]: The same list that was passed in.
        """
        return memory_fragments

    @immutable
    async def handle_duplicated(
        self, memory_fragments: List[T], new_memory_fragments: List[T]
    ) -> List[T]:
        """Merge new fragments into existing ones, handling duplicates.

        The base implementation performs no deduplication and simply appends the
        new fragments after the existing ones.

        Args:
            memory_fragments(List[T]): Existing fragments.
            new_memory_fragments(List[T]): Incoming fragments.

        Returns:
            List[T]: A new list with the existing fragments followed by the new ones.
        """
        combined = memory_fragments + new_memory_fragments
        return combined

    @mutable
    async def handle_overflow(
        self, memory_fragments: List[T]
    ) -> Tuple[List[T], List[T]]:
        """Handle memory overflow; the base implementation keeps everything.

        Args:
            memory_fragments(List[T]): Existing fragments.

        Returns:
            Tuple[List[T], List[T]]: The fragments to keep and the discarded
                fragments (always empty here).
        """
        kept: List[T] = memory_fragments
        return kept, []

    @abstractmethod
    @mutable
    async def clear(self) -> List[T]:
        """Remove every memory fragment.

        Returns:
            List[T]: The fragments that were removed.
        """

    @immutable
    async def get_insights(
        self, memory_fragments: List[T]
    ) -> List[InsightMemoryFragment[T]]:
        """Extract insights from every fragment concurrently.

        Args:
            memory_fragments(List[T]): The fragments to analyse.

        Returns:
            List[InsightMemoryFragment]: One insight entry per fragment, or an empty
                list when no insight extractor is configured.

        Raises:
            ValueError: If the extractor returns a falsy result for any fragment.
        """
        if not self.insight_extractor:
            return []
        pending = [
            self.insight_extractor.extract_insights(fragment, self.llm_client)
            for fragment in memory_fragments
        ]
        extracted = await asyncio.gather(*pending)
        # Every fragment must yield an insight entry.
        if not all(extracted):
            raise ValueError(
                "The number of insights is not equal to the number of memory fragments."
            )
        return [insight for insight in extracted]

    @immutable
    async def score_memory_importance(self, memory_fragments: List[T]) -> List[float]:
        """Score every fragment's importance, scaled by ``importance_weight``.

        Args:
            memory_fragments(List[T]): The fragments to score.

        Returns:
            List[float]: One weighted score per fragment. Without an importance
                scorer, each fragment gets ``5 * importance_weight``.
        """
        if not self.importance_scorer:
            return [5 * self.importance_weight for _ in memory_fragments]
        raw_scores = await asyncio.gather(
            *[
                self.importance_scorer.score_importance(fragment, self.llm_client)
                for fragment in memory_fragments
            ]
        )
        return [score * self.importance_weight for score in raw_scores]

    @property
    @immutable
    def real_memory_fragment_class(self) -> Type[T]:
        """Concrete fragment class configured for this memory.

        Raises:
            ValueError: If no fragment class has been configured.
        """
        fragment_cls = self._real_memory_fragment_class
        if fragment_cls:
            return fragment_cls
        raise ValueError("The real memory fragment class is not set.")


class SensoryMemory(Memory, Generic[T]):
    """Sensory memory.

    A transient buffer. When a write overflows the buffer, all fragments are
    scored, and those scoring at least ``threshold_to_short_term`` are returned to
    the caller as discarded fragments, meant for short-term memory.
    """

    importance_weight: float = 0.9
    threshold_to_short_term: float = 0.1

    def __init__(self, buffer_size: int = 0):
        """Create a sensory memory.

        Args:
            buffer_size(int): Number of fragments the buffer may hold before a write
                is treated as an overflow.
        """
        self._buffer_size = buffer_size
        self._fragments: List[T] = []
        # Serializes writers so the read-modify-write of ``_fragments`` is atomic.
        self._lock = asyncio.Lock()

    def structure_clone(
        self: "SensoryMemory[T]", now: Optional[datetime] = None
    ) -> "SensoryMemory[T]":
        """Return an empty sensory memory with the same buffer size and config."""
        m: SensoryMemory[T] = SensoryMemory(buffer_size=self._buffer_size)
        m._copy_from(self)
        return m

    @mutable
    async def write(
        self,
        memory_fragment: T,
        now: Optional[datetime] = None,
        op: WriteOperation = WriteOperation.ADD,
    ) -> Optional[DiscardedMemoryFragments[T]]:
        """Write a memory fragment to sensory memory.

        Args:
            memory_fragment(T): The new fragment.
            now(Optional[datetime]): The current time (unused here).
            op(WriteOperation): The write operation (unused here).

        Returns:
            Optional[DiscardedMemoryFragments]: Fragments flushed out of the buffer
                on overflow, or None when nothing was flushed.
        """
        discarded_fragments: List[T] = []
        async with self._lock:
            # The whole read-modify-write runs under the lock. handle_overflow
            # awaits the importance scorer. If this ran outside the lock, a
            # concurrent writer could start from the same stale ``_fragments``
            # snapshot, and one write's fragments would be lost when the other
            # assigns its result.
            fragments = await self.handle_duplicated(
                self._fragments, [memory_fragment]
            )
            if len(fragments) > self._buffer_size:
                fragments, discarded_fragments = await self.handle_overflow(fragments)
            await self.clear()
            self._fragments = fragments
        if not discarded_fragments:
            return None
        return DiscardedMemoryFragments(discarded_fragments, [])

    @immutable
    async def read(
        self,
        observation: str,
        alpha: Optional[float] = None,
        beta: Optional[float] = None,
        gamma: Optional[float] = None,
    ) -> List[T]:
        """Return the buffered fragments (the observation is not used)."""
        return self._fragments

    @mutable
    async def handle_overflow(
        self, memory_fragments: List[T]
    ) -> Tuple[List[T], List[T]]:
        """Handle memory overflow.

        For sensory memory, the buffer is emptied. Fragments whose weighted
        importance reaches ``threshold_to_short_term`` have their importance
        updated and are returned as discarded, to be transferred to short-term
        memory. The remaining fragments are dropped.

        Args:
            memory_fragments(List[T]): Existing memory fragments.

        Returns:
            Tuple[List[T], List[T]]: The fragments to keep (always empty) and the
                fragments to transfer to short-term memory.
        """
        scores = await self.score_memory_importance(memory_fragments)
        result = []
        for memory, score in zip(memory_fragments, scores):
            if score >= self.threshold_to_short_term:
                memory.update_importance(score)
                result.append(memory)
        return [], result

    @mutable
    async def clear(self) -> List[T]:
        """Remove all buffered fragments and return them.

        Deliberately does not take ``self._lock``. ``write`` calls this while
        already holding the lock, and ``asyncio.Lock`` is not reentrant.
        """
        fragments = self._fragments
        self._fragments = []
        return fragments


class ShortTermMemory(Memory, Generic[T]):
    """Short term memory.

    All memories are stored in computer memory. When the number of fragments
    exceeds the buffer size, the oldest ones are handed back to the caller for
    transfer to long-term memory.
    """

    def __init__(self, buffer_size: int = 5):
        """Create a short-term memory.

        Args:
            buffer_size(int): Maximum number of fragments kept before the oldest
                ones are transferred out.
        """
        self._buffer_size = buffer_size
        self._fragments: List[T] = []
        # Serializes writers so the read-modify-write of ``_fragments`` is atomic.
        self._lock = asyncio.Lock()

    def structure_clone(
        self: "ShortTermMemory[T]", now: Optional[datetime] = None
    ) -> "ShortTermMemory[T]":
        """Return an empty short-term memory with the same buffer size and config."""
        m: ShortTermMemory[T] = ShortTermMemory(
            buffer_size=self._buffer_size,
        )
        m._copy_from(self)
        return m

    @mutable
    async def write(
        self,
        memory_fragment: T,
        now: Optional[datetime] = None,
        op: WriteOperation = WriteOperation.ADD,
    ) -> Optional[DiscardedMemoryFragments[T]]:
        """Write a memory fragment to short-term memory.

        Args:
            memory_fragment(T): New memory fragment
            now(Optional[datetime]): The current time
            op(WriteOperation): Write operation

        Returns:
            Optional[DiscardedMemoryFragments]: The discarded memory fragments, None
                means no memory fragments are discarded. The discarded memory fragments
                should be transferred and stored in long-term memory.
        """
        async with self._lock:
            # Merge under the lock. transfer_to_long_term awaits the insight
            # extractor. If the merge ran before acquiring the lock, a concurrent
            # writer could build on a snapshot taken before that transfer, and its
            # assignment would put already-transferred fragments back.
            fragments = await self.handle_duplicated(
                self._fragments, [memory_fragment]
            )
            await self.clear()
            self._fragments = fragments
            discarded_memories = await self.transfer_to_long_term(memory_fragment)
            # NOTE(review): fragments discarded by handle_overflow are not returned
            # to the caller. The base implementation discards nothing, but
            # subclasses overriding handle_overflow should be aware of this.
            fragments, _discarded_fragments = await self.handle_overflow(
                self._fragments
            )
            self._fragments = fragments
            return discarded_memories

    @immutable
    async def read(
        self,
        observation: str,
        alpha: Optional[float] = None,
        beta: Optional[float] = None,
        gamma: Optional[float] = None,
    ) -> List[T]:
        """Return the stored fragments (the observation is not used)."""
        return self._fragments

    @mutable
    async def transfer_to_long_term(
        self, memory_fragment: T
    ) -> Optional[DiscardedMemoryFragments[T]]:
        """Transfer the oldest memories to long-term memory.

        This is a very simple strategy. Fragments beyond ``buffer_size`` are
        removed from the front of the list (the oldest ones) and returned, together
        with their insights, so the caller can store them in long-term memory.

        Args:
            memory_fragment(T): The fragment just written (unused here).

        Returns:
            Optional[DiscardedMemoryFragments]: The evicted fragments and their
                insights, or None when the buffer is not over capacity.
        """
        if len(self._fragments) <= self._buffer_size:
            return None
        overflow_cnt = len(self._fragments) - self._buffer_size
        # Take the oldest fragments BEFORE truncating. Slicing after the
        # truncation would return retained fragments instead of the evicted ones.
        overflow_fragments = self._fragments[:overflow_cnt]
        # Just keep the most recent memories in short-term memory
        self._fragments = self._fragments[overflow_cnt:]
        insights = await self.get_insights(overflow_fragments)
        return DiscardedMemoryFragments(overflow_fragments, insights)

    @mutable
    async def clear(self) -> List[T]:
        """Remove all fragments and return them.

        Deliberately does not take ``self._lock``. ``write`` calls this while
        already holding the lock, and ``asyncio.Lock`` is not reentrant.
        """
        fragments = self._fragments
        self._fragments = []
        return fragments

    @property
    @immutable
    def short_term_memories(self) -> List[T]:
        """Return short-term memories."""
        return self._fragments
